{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Style Transfer Network\n",
    "In this notebook we will go through the process of converting and evaluating the style transfer model, the one linked in the readme page, to CoreML. This model takes in an image and a style index (one of 26 possible styles) and outputs the stylized image. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We first download the TF model (.pb file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the model \n",
    "from __future__ import print_function\n",
    "import coremltools\n",
    "import os,sys\n",
    "import zipfile\n",
    "def download_file_and_unzip(url, dir_path='.'):\n",
    "    \"\"\"Download the frozen TensorFlow model and unzip it.\n",
    "    url - The URL address of the frozen file\n",
    "    dir_path - local directory\n",
    "    \"\"\"\n",
    "    if not os.path.exists(dir_path):\n",
    "        os.makedirs(dir_path)\n",
    "    k = url.rfind('/')\n",
    "    fname = url[k+1:]\n",
    "    fpath = os.path.join(dir_path, fname)\n",
    "\n",
    "    if not os.path.exists(fpath):\n",
    "        if sys.version_info[0] < 3:\n",
    "            import urllib\n",
    "            urllib.urlretrieve(url, fpath)\n",
    "        else:\n",
    "            import urllib.request\n",
    "            urllib.request.urlretrieve(url, fpath)\n",
    "    zip_ref = zipfile.ZipFile(fpath, 'r')\n",
    "    zip_ref.extractall(dir_path)\n",
    "    zip_ref.close()    \n",
    "\n",
    "inception_v1_url = 'https://storage.googleapis.com/download.tensorflow.org/models/stylize_v1.zip'\n",
    "download_file_and_unzip(inception_v1_url)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For conversion to CoreML, we need to find the input and output tensor names in the TF graph. This will also be required to run the TF graph for numerical accuracy check. Lets load the TF graph def and try to find the names. Inputs are generally the tensors that are outputs of the \"Placeholder\" op. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the TF graph definition\n",
    "import tensorflow as tf\n",
    "tf_model_path = './stylize_quantized.pb'\n",
    "with open(tf_model_path, 'rb') as f:\n",
    "    serialized = f.read()\n",
    "tf.reset_default_graph()\n",
    "original_gdef = tf.GraphDef()\n",
    "original_gdef.ParseFromString(serialized)\n",
    "\n",
    "# Lets get some details about a few ops in the beginning and the end of the graph\n",
    "with tf.Graph().as_default() as g:\n",
    "    tf.import_graph_def(original_gdef, name='')\n",
    "    ops = g.get_operations()\n",
    "    N = len(ops)\n",
    "    for i in range(N):\n",
    "        if ops[i].type == 'Placeholder':\n",
    "            for x in ops[i].outputs:\n",
    "                print(\"output name = {}, shape: {},\".format(x.name, x.get_shape())),\n",
    "                print('\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are two inputs: the image input named \"input:0\" and the style index input named \"style_num:0\". For finding the output lets print some info of the last few ops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tf.Graph().as_default() as g:\n",
    "    tf.import_graph_def(original_gdef, name='')\n",
    "    ops = g.get_operations()\n",
    "    N = len(ops)\n",
    "    for i in range(N-10,N):\n",
    "        print('\\n\\nop id {} : op type: \"{}\"'.format(str(i), ops[i].type));\n",
    "        print('input(s):'),\n",
    "        for x in ops[i].inputs:\n",
    "            print(\"name = {}, shape: {}, \".format(x.name, x.get_shape())),\n",
    "        print('\\noutput(s):'),\n",
    "        for x in ops[i].outputs:\n",
    "            print(\"name = {}, shape: {},\".format(x.name, x.get_shape())), "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Generally some knowledge about the network may be required to correctly determine the output. In this case the output of the \"Sigmoid\" op is the normalized image (between 0-1) which goes into the \"Mul\" op followed by the \"Squeeze\" op. The final output we are interested in is the tensor \"Squeeze:0\" which is the RGB image with values between 0-255. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now lets convert the model to CoreML. In this particular model, the TF graph can take an image of any size (it will produce the output image of the same size). However, CoreML requires us to specify the exact size of all its inputs. Hence we choose a fixed size for our image. Lets say 256.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tfcoreml\n",
    "mlmodel = tfcoreml.convert(\n",
    "        tf_model_path = tf_model_path,\n",
    "        mlmodel_path = './stylize.mlmodel',\n",
    "        output_feature_names = ['Squeeze:0'],\n",
    "        input_name_shape_dict = {'input:0':[1,256,256,3], 'style_num:0':[26]})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that the CoreML model expects two inputs: 'style\\_num_\\_0' which is a multiarray and a sequence of length 26 and 'input_\\_0' which is a multiarray corresponding to the image input and of shape (3,256,256). It produces a multiarray output called 'Squeeze_\\_0'"
   ]
  },
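  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The CoreML feature names above differ from the TF tensor names only in that the colon is replaced by two underscores (`input:0` becomes `input__0`). A minimal sketch of that mapping follows; it is inferred from the three names seen in this notebook rather than taken from the tfcoreml source, and `tf_to_coreml_name` is a hypothetical helper.\n",
    "\n",
    "```python\n",
    "def tf_to_coreml_name(tf_tensor_name):\n",
    "    # Assumption: only ':' needs replacing, as observed for the names in this notebook\n",
    "    return tf_tensor_name.replace(':', '__')\n",
    "\n",
    "for name in ['input:0', 'style_num:0', 'Squeeze:0']:\n",
    "    print(name, '->', tf_to_coreml_name(name))\n",
    "```\n",
    "\n",
    "Tensor names containing other special characters may be handled differently by the converter, so printing the converted model remains the reliable way to check the names."
   ]
  },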
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lets now grab an image and using coremltools see what the coreml model predicts.  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import PIL\n",
    "import requests\n",
    "from io import BytesIO\n",
    "from matplotlib.pyplot import imshow\n",
    "# This is an image of a golden retriever from Wikipedia\n",
    "img_url = 'https://upload.wikimedia.org/wikipedia/commons/9/93/Golden_Retriever_Carlos_%2810581910556%29.jpg'\n",
    "response = requests.get(img_url)\n",
    "%matplotlib inline\n",
    "img = PIL.Image.open(BytesIO(response.content))\n",
    "img = img.resize([256,256], PIL.Image.ANTIALIAS)\n",
    "img_np = np.asarray(img).astype(np.float32)\n",
    "print( img_np.shape, img_np.flatten()[:5])\n",
    "imshow(img_np/255.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Transpose the image since CoreML requires C,H,W format (3,256,256)\n",
    "coreml_image_input = np.transpose(img_np, (2,0,1))\n",
    "\n",
    "# The style index is a one-hot vector: a vector of zeros of length 26, with 1 in the index whose style we want\n",
    "index = np.zeros((26)).astype(np.float32)\n",
    "index[0] = 1 #Lets say we want to get style 0\n",
    "\n",
    "# CoreML Multi array interpreation is (Seq, Batch, C,H,W). Hence the style index input, which is a sequence,\n",
    "# must be of shape (26,1,1,1,1)\n",
    "coreml_style_index = index[:,np.newaxis,np.newaxis,np.newaxis,np.newaxis]\n",
    "\n",
    "coreml_input = {'input__0': coreml_image_input, 'style_num__0': coreml_style_index}\n",
    "coreml_out = mlmodel.predict(coreml_input, useCPUOnly = True)['Squeeze__0']\n",
    "print( coreml_out.shape, coreml_out.flatten()[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Transpose back for visualization with imshow\n",
    "coreml_out = np.transpose(np.squeeze(coreml_out), (1,2,0))\n",
    "imshow(coreml_out/255.0)"
   ]
  },
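  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The layout handling in the cells above can be collected into small helpers. This is a sketch that uses only NumPy; the helper names (`hwc_to_chw`, `chw_to_hwc`, `one_hot_style`) are hypothetical and not part of coremltools or tfcoreml, and a random array stands in for the real image.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def hwc_to_chw(img_hwc):\n",
    "    # (H, W, C) -> (C, H, W), the layout the converted model expects\n",
    "    return np.transpose(img_hwc, (2, 0, 1))\n",
    "\n",
    "def chw_to_hwc(img_chw):\n",
    "    # Drop singleton axes, then (C, H, W) -> (H, W, C) for imshow\n",
    "    return np.transpose(np.squeeze(img_chw), (1, 2, 0))\n",
    "\n",
    "def one_hot_style(style_idx, num_styles=26):\n",
    "    # One-hot vector reshaped to (Seq, Batch, C, H, W) = (num_styles, 1, 1, 1, 1)\n",
    "    index = np.zeros(num_styles, dtype=np.float32)\n",
    "    index[style_idx] = 1\n",
    "    return index.reshape(num_styles, 1, 1, 1, 1)\n",
    "\n",
    "dummy = np.random.rand(256, 256, 3).astype(np.float32)  # stand-in for img_np\n",
    "chw = hwc_to_chw(dummy)\n",
    "print(chw.shape, one_hot_style(10).shape)\n",
    "```\n",
    "\n",
    "Transposing to C,H,W and back returns the original array, which is a quick way to check that the two permutations are inverses of each other."
   ]
  },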
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That looks cool! Lets try another style. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "index = np.zeros((26)).astype(np.float32)\n",
    "index[10] = 1 \n",
    "coreml_style_index = index[:,np.newaxis,np.newaxis,np.newaxis,np.newaxis]\n",
    "coreml_input = {'input__0': coreml_image_input, 'style_num__0': coreml_style_index}\n",
    "coreml_out = mlmodel.predict(coreml_input, useCPUOnly = True)['Squeeze__0']\n",
    "coreml_out = np.transpose(np.squeeze(coreml_out), (1,2,0))\n",
    "imshow(coreml_out/255.0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lets also try to evaluate the same image and style with the TF model to check that the conversion was correct (we should get similar output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf_img = np.expand_dims(img_np,axis=0)\n",
    "tf_input_name_image = 'input:0'\n",
    "tf_input_name_style_index = 'style_num:0'\n",
    "feed_dict = {tf_input_name_image: tf_img, tf_input_name_style_index: index}\n",
    "tf_output_name = 'Squeeze:0'\n",
    "with tf.Session(graph = g) as sess:\n",
    "    tf_out = sess.run(tf_output_name, \n",
    "                      feed_dict=feed_dict)\n",
    "imshow(tf_out/255.0)    "
   ]
  },
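  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Visual inspection can be complemented by a numerical check. Below is a sketch of one way to compare two outputs, using the maximum absolute difference and a signal-to-noise ratio in dB. The function name `compare_outputs` and the choice of metrics are assumptions of this example, and synthetic arrays stand in for `tf_out` and `coreml_out`.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def compare_outputs(reference, candidate):\n",
    "    # Return (max absolute difference, SNR in dB) of candidate relative to reference\n",
    "    reference = np.asarray(reference, dtype=np.float64)\n",
    "    candidate = np.asarray(candidate, dtype=np.float64)\n",
    "    noise = candidate - reference\n",
    "    max_abs_diff = np.max(np.abs(noise))\n",
    "    # Small epsilon avoids division by zero when the arrays are identical\n",
    "    snr_db = 10 * np.log10(np.sum(reference ** 2) / (np.sum(noise ** 2) + 1e-12))\n",
    "    return max_abs_diff, snr_db\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "ref = rng.uniform(0, 255, size=(256, 256, 3))     # stand-in for tf_out\n",
    "cand = ref + rng.normal(0, 0.01, size=ref.shape)  # stand-in for coreml_out\n",
    "max_diff, snr = compare_outputs(ref, cand)\n",
    "print('max abs diff: {:.4f}, SNR: {:.1f} dB'.format(max_diff, snr))\n",
    "```\n",
    "\n",
    "With the notebook's own variables, `compare_outputs(tf_out, coreml_out)` can be called once both arrays are in H,W,C layout and were produced with the same style index."
   ]
  },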
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us look at the input/output description of the CoreML model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(mlmodel)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that input \"input_\\_0\" and the output \"Squeeze_\\_0\" are both multiarrays. Since they represent images, it may be more convenient to make them image types. The input can be made of type image by converting again and passing the \"image_input_names\" argument to the convert function call."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mlmodel = tfcoreml.convert(\n",
    "        tf_model_path = tf_model_path,\n",
    "        mlmodel_path = './stylize.mlmodel',\n",
    "        output_feature_names = ['Squeeze:0'],\n",
    "        input_name_shape_dict = {'input:0':[1,256,256,3], 'style_num:0':[26]},\n",
    "        image_input_names = ['input:0'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(mlmodel)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We see that the input is of type image now. To convert the output type to image, we realize that the mlmodel is a protobuf format and can be edited directly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "spec = mlmodel.get_spec()\n",
    "output = spec.description.output[0]\n",
    "print(output.name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from coremltools.proto import FeatureTypes_pb2 as _FeatureTypes_pb2\n",
    "output.type.imageType.colorSpace = _FeatureTypes_pb2.ImageFeatureType.ColorSpace.Value('RGB')\n",
    "output.type.imageType.width = 256\n",
    "output.type.imageType.height = 256\n",
    "print(spec.description)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now both input and output are image. Lets save the spec and call predict again."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coremltools.models.utils.save_spec(spec, './stylize.mlmodel')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mlmodel = coremltools.models.utils._get_model(spec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "coreml_input = {'input__0': img, 'style_num__0': coreml_style_index} #now we can pass in the PIL image\n",
    "coreml_out = mlmodel.predict(coreml_input, useCPUOnly = True)['Squeeze__0'] #coreml_out is also a PIL image\n",
    "imshow(np.asarray(coreml_out).astype(np.float32)/255.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
